import numpy as np
import pandas as pd
import scipy.spatial.distance as distances
from scipy.stats import entropy
import torch


def align_dataframes(df1, df2):
    """
    Restrict two DataFrames to the set of columns they have in common.

    Columns that appear in only one of the two frames are discarded.

    Parameters:
    - df1: pandas DataFrame
    - df2: pandas DataFrame

    Returns:
    - Tuple of pandas DataFrames (df1_aligned, df2_aligned) with the same columns.
      Both are explicit copies. Callers can therefore clean them in place
      (e.g. ``dropna(inplace=True)``) without affecting the inputs and without
      pandas raising a SettingWithCopyWarning.
    """
    common_columns = df1.columns.intersection(df2.columns)

    # ``df[cols]`` is flagged by pandas as derived from ``df``; copying drops that flag.
    df1_aligned = df1[common_columns].copy()
    df2_aligned = df2[common_columns].copy()

    return df1_aligned, df2_aligned

def compute_kernel(x, y, kernel_type="rbf", sigma=None):
    """
    Evaluate a kernel between every row of ``x`` and every row of ``y``.

    Args:
        x (Tensor): Samples of shape (n_x, d).
        y (Tensor): Samples of shape (n_y, d).
        kernel_type (str): Only "rbf" (Gaussian) is supported.
        sigma (float or Tensor, optional): RBF bandwidth. When None it defaults to
            median(pdist(x)) + median(pdist(y)).

    Returns:
        Tensor of shape (n_x, n_y) whose (i, j) entry is
        exp(-||x_i - y_j||^2 / (2 * sigma^2)).

    Raises:
        ValueError: If ``kernel_type`` is not "rbf".
    """
    if kernel_type != "rbf":
        raise ValueError(f"Unsupported kernel type: {kernel_type}")

    bandwidth = (
        sigma
        if sigma is not None
        else torch.median(torch.pdist(x)) + torch.median(torch.pdist(y))
    )
    squared_dist = torch.cdist(x, y, p=2) ** 2
    return torch.exp(-squared_dist / (2 * bandwidth**2))


def maximum_mean_discrepancy(x, y, kernel_type="rbf", sigma=None):
    """
    Compute the (biased) Maximum Mean Discrepancy (MMD) between two sets of samples x and y.

    The estimate is mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y)). The diagonal
    terms of the Gram matrices are included (biased / V-statistic form).

    Args:
        x (Tensor): A PyTorch tensor of shape (n_x, d), where n_x is the number of samples in x and d is the dimension.
        y (Tensor): A PyTorch tensor of shape (n_y, d), where n_y is the number of samples in y and d is the dimension.
        kernel_type (str): The type of kernel to use. Currently, only 'rbf' (Radial Basis Function) is supported.
        sigma (float, optional): The bandwidth parameter for the RBF kernel. If None, it is estimated once as
            median(pdist(x)) + median(pdist(y)) and shared by all three kernel evaluations.

    Returns:
        Tensor: 0-dim tensor holding the MMD value (use ``.item()`` to get a Python float).

    Raises:
        ValueError: If ``kernel_type`` is unsupported (raised by ``compute_kernel``).
    """
    if sigma is None and kernel_type == "rbf":
        # Estimate the bandwidth ONCE. If compute_kernel estimated it per call, k_xx,
        # k_yy and k_xy would use 2*med(x), 2*med(y) and med(x)+med(y) respectively:
        # three different kernels, which does not form a valid MMD.
        sigma = torch.median(torch.pdist(x)) + torch.median(torch.pdist(y))

    k_xx = compute_kernel(x, x, kernel_type, sigma)
    k_yy = compute_kernel(y, y, kernel_type, sigma)
    k_xy = compute_kernel(x, y, kernel_type, sigma)

    mmd = torch.mean(k_xx) + torch.mean(k_yy) - 2 * torch.mean(k_xy)
    return mmd

def distrib_mean(distribution: pd.DataFrame):
    """
    Return the mean of a discrete distribution.

    ``distribution`` is a pandas Series (or a single-column DataFrame). Its index holds
    the support (outcome values) and its values hold the matching probabilities. The
    result is sum(value * probability). Probabilities are used as given and are not
    re-normalised.
    """
    values = np.asarray(distribution.index, dtype=np.float64)
    # Flatten so a one-column DataFrame pairs element-wise with the index instead of
    # broadcasting (n,) * (n, 1) into an (n, n) matrix.
    probs = np.asarray(distribution.values, dtype=np.float64).ravel()
    # np.sum stays in native code; the builtin sum iterated element by element in Python.
    return np.sum(values * probs)

def get_ground_truth_ate(ground_truth: pd.DataFrame, treatment, outcome) -> float:
    """
    Return the ground-truth Average Treatment Effect.

    The ATE is the difference between the means (via ``distrib_mean``) of the
    "Treated Interventional Distribution" and the "Control Interventional Distribution"
    stored in ``ground_truth``.

    ``treatment`` and ``outcome`` are not used. They are kept for interface compatibility.
    """
    treated_mean = distrib_mean(ground_truth["Treated Interventional Distribution"])
    control_mean = distrib_mean(ground_truth["Control Interventional Distribution"])
    return treated_mean - control_mean


def get_ground_truth_cate(
    ground_truth: pd.DataFrame, treatment, outcome, evidence
) -> float:
    """
    Return the ground-truth Conditional Average Treatment Effect.

    The CATE is the difference between the means (via ``distrib_mean``) of the
    "Treated Conditional Interventional Distribution" and the
    "Control Conditional Interventional Distribution" stored in ``ground_truth``.

    ``treatment``, ``outcome`` and ``evidence`` are not used. The conditioning on
    ``evidence`` is presumably already baked into the stored conditional distributions
    (TODO confirm against the data loader). The parameters are kept for interface
    compatibility.

    Returns:
        A scalar: the difference of the two distribution means.
    """
    treated_dist = ground_truth["Treated Conditional Interventional Distribution"]
    control_dist = ground_truth["Control Conditional Interventional Distribution"]
    cate = distrib_mean(treated_dist) - distrib_mean(control_dist)
    return cate


def _has_result(results, key):
    """Return True if ``key`` is present in ``results`` and its value is not None."""
    return key in results and results[key] is not None


def _scalar_error_metrics(estimate, ground_truth, label):
    """Error metrics for a scalar effect estimate (``label`` is "ATE" or "CATE")."""
    estimate = np.array(estimate).reshape(1,)
    ground_truth = np.array(ground_truth).reshape(1,)
    return {
        # NOTE(review): for a single scalar the Euclidean distance equals the absolute
        # error, so "<label> MSE" holds the same value as "<label> MAE" (it is not a
        # squared error). Kept unchanged so stored results stay comparable; confirm
        # which quantity is intended.
        f"{label} MSE": distances.euclidean(estimate, ground_truth),
        f"{label} MAE": np.abs(estimate - ground_truth),
        f"Ground-truth {label}": ground_truth,
    }


def _distribution_metrics(estimate, ground_truth, prefix=""):
    """Distances between an estimated and a ground-truth discrete distribution.

    Returned keys:
    - "<prefix>Jensen-Shannon": Jensen-Shannon distance.
    - "<prefix>Kullback-Leibler": KL(estimate || ground truth).
    - "<prefix>MSE": the Euclidean (L2) distance.
    - "<prefix>MAE": the L1 distance between the two vectors.
    """
    p = np.array(estimate)
    q = np.array(ground_truth)
    return {
        f"{prefix}Jensen-Shannon": distances.jensenshannon(p, q),
        f"{prefix}Kullback-Leibler": entropy(p, q),
        f"{prefix}MSE": distances.euclidean(p, q),
        f"{prefix}MAE": np.sum(np.abs(p - q)),
    }


def _drop_non_finite_rows(df):
    """Return a copy of ``df`` with +/-inf turned into NaN and every row containing NaN dropped."""
    return df.replace([np.inf, -np.inf], np.nan).dropna()


def _samples_mmd(samples, reference, max_samples=10000):
    """MMD between two column-aligned sample DataFrames, ignoring non-finite rows.

    Only the first ``max_samples`` rows of each frame are used. The kernel matrices
    grow quadratically with the number of rows, so the full dataset would need too
    much memory. Returns NaN when either frame is empty after cleaning.
    """
    samples = _drop_non_finite_rows(samples)
    reference = _drop_non_finite_rows(reference)
    if samples.empty or reference.empty:
        # Nothing to compare (e.g. no ground-truth row matches the evidence).
        return float("nan")
    sample_tensor = torch.from_numpy(samples.to_numpy()).float()[:max_samples]
    reference_tensor = torch.from_numpy(reference.to_numpy()).float()[:max_samples]
    return maximum_mean_discrepancy(sample_tensor, reference_tensor).item()


def seed_metrics(
    results: pd.DataFrame, gt_distributions: pd.DataFrame, treatment, outcome, evidence, treated_ground_truth
) -> pd.DataFrame:
    """
    Calculate evaluation metrics for one seed of a causal-inference run.

    A metric group is computed only if its key is present in ``results`` with a
    non-None value. ``results`` may therefore carry any subset of:
    "ATE", "CATE", "Interventional Distribution", "Interventional Samples",
    "Control Interventional Distribution", "Conditional Control Distribution",
    "Conditional Interventional Samples" and "Conditional Interventional Distribution".

    Args:
        results: Dict-like container of model outputs with the keys above.
            Distribution entries are array-likes (converted with ``np.array``).
            Sample entries are DataFrames.
        gt_distributions: Dict-like/DataFrame holding the ground-truth distributions.
            Keys are "Treated Interventional Distribution",
            "Control Interventional Distribution",
            "Treated Conditional Interventional Distribution" and
            "Control Conditional Interventional Distribution".
        treatment: Name of the treatment variable (forwarded to the ground-truth helpers).
        outcome: Name of the outcome variable (forwarded to the ground-truth helpers).
        evidence: Mapping {column: value}. It is used to empirically condition the
            ground-truth samples for "Conditional MMD" and is also forwarded to
            ``get_ground_truth_cate``.
        treated_ground_truth: Mapping whose "treated_data" entry is a DataFrame of
            ground-truth samples under treatment.

    Returns:
        pd.DataFrame: A single row (index 0) with one column per computed metric.
        "MSE" columns hold the Euclidean (L2) distance and "MAE" columns the L1
        distance between estimate and ground truth. Neither is averaged per element.

    Example::

        metrics = seed_metrics(results, gt_distributions, "treatment", "outcome",
                               {"Z": 1}, {"treated_data": treated_df})
    """
    metrics_dict = {}

    if _has_result(results, "ATE"):
        metrics_dict.update(
            _scalar_error_metrics(
                results["ATE"],
                get_ground_truth_ate(gt_distributions, treatment, outcome),
                "ATE",
            )
        )

    if _has_result(results, "CATE"):
        metrics_dict.update(
            _scalar_error_metrics(
                results["CATE"],
                get_ground_truth_cate(gt_distributions, treatment, outcome, evidence),
                "CATE",
            )
        )

    if _has_result(results, "Interventional Distribution"):
        # Treated interventional distribution metrics
        metrics_dict.update(
            _distribution_metrics(
                results["Interventional Distribution"],
                gt_distributions["Treated Interventional Distribution"],
            )
        )

    if _has_result(results, "Interventional Samples"):
        # Treated interventional samples vs. ground-truth treated samples
        samples, reference = align_dataframes(
            results["Interventional Samples"], treated_ground_truth["treated_data"]
        )
        metrics_dict["MMD"] = _samples_mmd(samples, reference)

    if _has_result(results, "Control Interventional Distribution"):
        metrics_dict.update(
            _distribution_metrics(
                results["Control Interventional Distribution"],
                gt_distributions["Control Interventional Distribution"],
                prefix="Control ",
            )
        )

    if _has_result(results, "Conditional Control Distribution"):
        metrics_dict.update(
            _distribution_metrics(
                results["Conditional Control Distribution"],
                gt_distributions["Control Conditional Interventional Distribution"],
                prefix="Conditional Control ",
            )
        )

    if _has_result(results, "Conditional Interventional Samples"):
        # Empirically condition the ground-truth samples on the evidence BEFORE
        # aligning, so evidence columns are available even if the model samples omit them.
        reference = treated_ground_truth["treated_data"]
        for col, ev_value in (evidence if evidence is not None else {}).items():
            reference = reference[reference[col] == ev_value]
        # Use the aligned, cleaned frames on both sides. Previously the raw model
        # samples and the unaligned ground truth were compared.
        samples, reference = align_dataframes(
            results["Conditional Interventional Samples"], reference
        )
        metrics_dict["Conditional MMD"] = _samples_mmd(samples, reference)

    if _has_result(results, "Conditional Interventional Distribution"):
        metrics_dict.update(
            _distribution_metrics(
                results["Conditional Interventional Distribution"],
                gt_distributions["Treated Conditional Interventional Distribution"],
                prefix="Conditional ",
            )
        )

    return pd.DataFrame(metrics_dict, index=[0])


def experiment_metrics(
    data_df: pd.DataFrame,
    ground_truth: float,
    dataset_name: str = "",
    method_name: str = "",
) -> pd.DataFrame:
    """
    Compute per-row errors of the "Estimate" column against a ground-truth value.

    Parameters:
    - data_df (pd.DataFrame): DataFrame with a numeric "Estimate" column, one row per estimate.
    - ground_truth (float): The ground truth value to compare against.
    - dataset_name (str, optional): Label written to the "Dataset" column (default: "").
    - method_name (str, optional): Label written to the "Method" column (default: "").

    Returns:
    - pd.DataFrame with one row per row of ``data_df``, in the same order and with a
      fresh 0..n-1 index. Its columns are:
        - "Method", "Dataset": the labels passed in;
        - "SE": squared error (estimate - ground_truth) ** 2;
        - "MAE": absolute error |estimate - ground_truth| (per row, not yet averaged).

    No summary statistics are computed here. Call ``.describe()`` on the result (or on
    a concatenation of several results) to aggregate them.

    Example usage:
    >>> data = pd.DataFrame({"Estimate": [1.5, 2.5, 4.0]})
    >>> metrics = experiment_metrics(data, 2.0, dataset_name="Dataset A", method_name="Method 1")
    >>> print(metrics)
         Method    Dataset    SE  MAE
    0  Method 1  Dataset A  0.25  0.5
    1  Method 1  Dataset A  0.25  0.5
    2  Method 1  Dataset A  4.00  2.0
    """
    estimates = np.asarray(data_df["Estimate"], dtype=float)
    errors = estimates - ground_truth
    n_rows = len(estimates)

    # Build the frame in one shot. Concatenating row by row is quadratic, leaves a
    # duplicate all-zero index, and concatenating onto an empty frame can emit a
    # FutureWarning in recent pandas.
    return pd.DataFrame(
        {
            "Method": [method_name] * n_rows,
            "Dataset": [dataset_name] * n_rows,
            "SE": errors**2,
            "MAE": np.abs(errors),
        }
    )
